{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import random\n",
    "import time\n",
    "from sklearn.utils import shuffle\n",
    "from joblib import Parallel,delayed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(33465, 6993)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the pre-computed feature blocks. They are concatenated column-wise\n",
    "# below, so each block is expected to describe the same training rows\n",
    "# (TODO confirm the blocks share an identical row index).\n",
    "data_nodup = joblib.load('../../scrum_data/train_data/train_nodup.lz4')\n",
    "data_tag = joblib.load('../../scrum_data/train_data/train_tag.lz4')\n",
    "data_date = joblib.load('../../scrum_data/train_data/train_date.lz4')\n",
    "data_week = joblib.load('../../scrum_data/train_data/train_week.lz4')\n",
    "data_null_sign = joblib.load('../../scrum_data/train_data/train_null_sign90.lz4')\n",
    "data_majcnt = joblib.load('../../scrum_data/train_data/train_majcnt.lz4')\n",
    "data_cf_fs = joblib.load('../../scrum_data/train_data/train_cf_86_fs10.lz4')\n",
    "\n",
    "# Stack the blocks side by side into one wide feature frame.\n",
    "data = pd.concat(\n",
    "    [data_nodup, data_tag, data_date, data_week,\n",
    "     data_null_sign, data_majcnt, data_cf_fs],\n",
    "    axis=1,\n",
    "    copy=False,\n",
    ")\n",
    "# Missing values are replaced by the sentinel -1.\n",
    "x_df = data.fillna(-1)\n",
    "\n",
    "# Labels live in a separate CSV; only its 'label' column is read and then\n",
    "# flattened to a 1-D array.\n",
    "data_label = pd.read_csv('../../preprocess_data/train_y_33465.csv', usecols=['label'])\n",
    "y = data_label.values.ravel()\n",
    "\n",
    "x_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_by_tag(df,y):\n",
    "    '''Attach the labels to the features and split the rows by their 'tag' value.\n",
    "\n",
    "    The labels are attached to a copy of `df`, so the caller's frame is left\n",
    "    untouched (it does not gain a 'label' column as a side effect).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Feature frame; must contain a 'tag' column.\n",
    "    y : array-like\n",
    "        One label per row of `df`, in row order.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (df_0, df_1) : tuple of pandas.DataFrame\n",
    "        Rows with tag == 0 and rows with tag == 1, each carrying an extra\n",
    "        'label' column. Rows with any other tag value are in neither frame.\n",
    "    '''\n",
    "    # assign() works on a copy; writing df['label'] = y would mutate the\n",
    "    # caller's frame and leak the target into the feature matrix.\n",
    "    labelled = df.assign(label=y)\n",
    "    tag = labelled['tag']\n",
    "    df_0 = labelled[tag==0]\n",
    "    df_1 = labelled[tag==1]\n",
    "    return df_0,df_1\n",
    "\n",
    "def gen_one_fold(df_tmp_0,df_1,i):\n",
    "    '''Build fold `i` from two labelled frames and write it to ./train_data/.\n",
    "\n",
    "    `df_tmp_0` and `df_1` are stacked row-wise; both must carry a 'label'\n",
    "    column. The features (everything except 'label') are dumped to\n",
    "    x_df_fold{i}.lz4 and the labels, as a 1-D array, to y_fold{i}.lz4.\n",
    "    Nothing is returned.\n",
    "    '''\n",
    "    fold_df = pd.concat([df_tmp_0,df_1],axis=0)\n",
    "    fold_labels = fold_df['label'].values.ravel()\n",
    "    fold_features = fold_df.drop(columns=['label'])\n",
    "    joblib.dump(fold_features,'./train_data/x_df_fold{}.lz4'.format(i),compress='lz4')\n",
    "    joblib.dump(fold_labels,'./train_data/y_fold{}.lz4'.format(i),compress='lz4')\n",
    "        \n",
    "def gen_n_fold(x_df,y,nfold=10,n_jobs=18):\n",
    "    '''Write `nfold` training folds to ./train_data/.\n",
    "\n",
    "    The tag == 0 rows are sampled without replacement: they are shuffled with\n",
    "    a fixed seed and cut into `nfold` equally sized, non-overlapping segments.\n",
    "    The tag == 1 rows are kept in full in every fold. Each fold is written by\n",
    "    gen_one_fold as x_df_fold{i}.lz4 / y_fold{i}.lz4.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x_df : pandas.DataFrame\n",
    "        Feature frame; must contain a 'tag' column.\n",
    "    y : array-like\n",
    "        One label per row of `x_df`, in row order.\n",
    "    nfold : int, default 10\n",
    "        Number of folds. With nfold == 1, `x_df` and `y` are dumped unchanged\n",
    "        as fold 0 and no splitting takes place.\n",
    "    n_jobs : int, default 18\n",
    "        Number of joblib workers used to write the folds.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If nfold < 1, or if there are fewer tag == 0 rows than folds.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    When the number of tag == 0 rows is not a multiple of `nfold`, the\n",
    "    leftover rows (fewer than `nfold`) are not written to any fold.\n",
    "    '''\n",
    "    if nfold < 1:\n",
    "        raise ValueError('nfold must be >= 1, got {}'.format(nfold))\n",
    "    if nfold == 1:\n",
    "        joblib.dump(x_df,'./train_data/x_df_fold{}.lz4'.format(0),compress='lz4')\n",
    "        joblib.dump(y,'./train_data/y_fold{}.lz4'.format(0),compress='lz4')\n",
    "        return\n",
    "\n",
    "    df_0,df_1 = split_by_tag(x_df,y)\n",
    "\n",
    "    # Fixed seed so the fold membership is reproducible across runs.\n",
    "    df_0 = shuffle(df_0,random_state=2018)\n",
    "\n",
    "    seg_len = len(df_0)//nfold\n",
    "    if seg_len == 0:\n",
    "        raise ValueError('cannot build {} folds from {} rows with tag == 0'.format(nfold,len(df_0)))\n",
    "\n",
    "    # End position of every fold. Enumerating exactly nfold boundaries keeps\n",
    "    # the last fold when len(df_0) is an exact multiple of nfold (a range()\n",
    "    # that stops at len(df_0) silently drops it) and guarantees that the\n",
    "    # folds never overlap.\n",
    "    seg_list = [seg_len*(k+1) for k in range(nfold)]\n",
    "\n",
    "    # gen_one_fold only writes files, so the results of Parallel are discarded.\n",
    "    Parallel(n_jobs=n_jobs, verbose=10)(\n",
    "        delayed(gen_one_fold)(df_0.iloc[seg_pos-seg_len:seg_pos],df_1,i)\n",
    "        for i,seg_pos in enumerate(seg_list))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 5.94 s, sys: 792 ms, total: 6.73 s\n",
      "Wall time: 14 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Single-fold run: the full feature frame and labels are dumped as fold 0.\n",
    "gen_n_fold(x_df, y, nfold=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read back the fold-0 feature dump written by gen_n_fold\n",
    "# (presumably a sanity check that the file loads; the result is not used further here).\n",
    "x = joblib.load('./train_data/x_df_fold0.lz4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
